iT邦幫忙

2022 iThome 鐵人賽

DAY 5
0
AI & Data

JAX 好好玩系列 第 5

JAX 好好玩 (5) : JAX.NUMPY (1) : 一個更好的 Numpy

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

在 JAX 官方的教學網頁 JAX Quickstart [5.1] 上,開宗明義就說:

JAX 是一個可以跑在 CPU, GPU, 以及 TPU 上的 Numpy, …
(JAX is NumPy on the CPU, GPU, and TPU, … )

怎麼做到的呢?JAX 提供了與 Numpy 幾乎完全一致的 API !這些 API 是在 jax.numpy 之下,而習慣上我們是這樣使用的:

import jax.numpy as jnp

它對應了 Numpy 的習慣用法:

import numpy as np

在大部份的情況下,jax.numpy 的 API 用法,和標準 Numpy API 的用法相同,僅有少數的例外,這些例外,老頭會在日後加以說明。

現在我們先來看看一些簡單的例子:

# import jax.numpy and numpy
import jax.numpy as jnp
import numpy as np

# declare the data 
#==========================================================================
# jax.numpy
x = jnp.arange(10)
# numpy
y = np.arange(10)

# show the declared data 
#==========================================================================
print(f'Data defined by JAX : {x}')
print(f'Data defined by Numpy : {y}')

output:
Data defined by JAX : [0 1 2 3 4 5 6 7 8 9]
Data defined by Numpy : [0 1 2 3 4 5 6 7 8 9]

# operation: sum 
#==========================================================================
# jax.numpy
sum_jnp = jnp.sum(x)
# numpy
sum_np = np.sum(y)

print(f'Sum by JAX : {sum_jnp}')
print(f'Sum by Numpy : {sum_np}')

output:
Sum by JAX : 45
Sum by Numpy : 45

# operation: dot
#==========================================================================
# jax.numpy
dot_jnp = jnp.dot(x,x)
# numpy
dot_np = np.dot(y,y)

print(f'Dot by JAX : {dot_jnp}')
print(f'Dot by Numpy : {dot_np}')

output:
Dot by JAX : 285
Dot by Numpy : 285

在以上的例子,jax.numpy 的 API 和 Numpy 的 API 幾乎是一對一的對應,語法 (syntax) 和語意 (semantics) 也完全相同。

在繼續談 jax.numpy 之前,老頭想先簡單介紹 %timeit 這個魔術指令 (magic command) [5.2]。%timeit 可以計算一列 Python 敍述 (statement) 執行時所需要的時間。例如:

%timeit  100.0 / 5.0

output: 11.2 ns ± 0.0468 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)

%timeit 採用了「二層重覆執行」的方式,來得到精確且有統計意義的結果。

第一層:執行 Python 敍述 N 次 (N 個 loops),計算平均時間,得到 Python 敍述執行一次的時間值。
第二層:重複第一層 R 次 (R 個 runs),得到 R 個結果,利用這 R 個結果,可以得到平均數、標準差等統計量。

用選項 -n 來指定第一層的次數,用 -r 來指定第二層的次數,沒有指定的話,則選用既定值 (default values)。

-n : default value 由系統自行判斷要達要適當精確度所需的重覆次數。
-r : default value 7

讀者可能會納悶,為什麼需要第一層的重覆呢?因為由作業系統 (operation system; OS) 所提供的最小計時單位,可能會比一些簡單的敍述所需的時間大得多,如上面的例子,100.0 / 5.0 所需的時間僅需 11.2 ns,這不是作業系統計時器所可以量出來的,所以 %timeit 必須重覆執行它 100000000 loops 才能正確的估算它的執行時間。

做個實驗

%timeit -n 1 -r 1 100.0/5.0

output:
668 ns ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

這個實驗是在 google colab 的 linux 虛擬機 (VM) 上執行的,版本是 5.4.188+

!uname -a

output:
Linux 47cad5f60b70 5.4.188+ #1 SMP Sun Apr 24 10:03:06 PDT 2022 x86_64 x86_64 x86_64 GNU/Linux

linux 能夠支援的計時精度,若導入了高解析度計時功能 (HRT; high-resolution timer)大約只到微秒 (micro-second) 等級,因此只執行 100.0/5.0 所得到結果比實際所需執行時間大得多!

希望這個簡單的介紹能讓大家了解它的基本用法,接下來 %timeit 會被用來計算 jax.numpy 及 Numpy 執行同一個運算所需要的時間,讓大家看看 jax.numpy 到底可以快多少!

註:

[5.1] 參考 JAX Quickstart

[5.2] 老頭在此不對「魔力指令」多加說明,讀者們可以參考 IPython 的官方文件 Build-in magic commands


上一篇
JAX 好好玩 (4) : JAX 是什麼 ? 概說
下一篇
JAX 好好玩 (6) : JAX.NUMPY (2) : 虛擬亂數產生器
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言